{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torchvision import datasets, transforms\n",
    "import argparse\n",
    "import time\n",
    "import os\n",
    "\n",
    "# Training configuration.  argparse keeps the same flags usable if this code is\n",
    "# exported to a script; inside the notebook an empty argv (args=[]) is parsed so\n",
    "# every option takes its default and the kernel's own command line is ignored.\n",
    "parser = argparse.ArgumentParser(description='PyTorch MNIST Training')\n",
    "parser.add_argument('--batch-size', type=int, default=128, metavar='N',\n",
    "                    help='input batch size for training (default: 128)')\n",
    "parser.add_argument('--test-batch-size', type=int, default=128, metavar='N',\n",
    "                    help='input batch size for testing (default: 128)')\n",
    "parser.add_argument('--epochs', type=int, default=5, metavar='N',\n",
    "                    help='number of epochs to train')\n",
    "parser.add_argument('--lr', type=float, default=0.01, metavar='LR',\n",
    "                    help='learning rate')\n",
    "parser.add_argument('--no-cuda', action='store_true', default=False,\n",
    "                    help='disables CUDA training')\n",
    "parser.add_argument('--seed', type=int, default=1, metavar='S',\n",
    "                    help='random seed (default: 1)')\n",
    "parser.add_argument('--model-dir', default='./model-mnist-cnn',\n",
    "                    help='directory of model for saving checkpoint')\n",
    "parser.add_argument('--load-model', action='store_true', default=False,\n",
    "                    help='load model or not')\n",
    "\n",
    "args = parser.parse_args(args=[])\n",
    "\n",
    "# Directory that receives the final checkpoint.  exist_ok=True avoids the\n",
    "# check-then-create race of a separate os.path.exists() test.\n",
    "os.makedirs(args.model_dir, exist_ok=True)\n",
    "\n",
    "# Use the GPU only when one is present and --no-cuda was not given.  The device\n",
    "# is derived from the same flag as the DataLoader options below; otherwise\n",
    "# batches would be pinned for a host-to-GPU copy that never happens and the\n",
    "# --no-cuda switch would have no effect.\n",
    "use_cuda = not args.no_cuda and torch.cuda.is_available()\n",
    "device = torch.device('cuda' if use_cuda else 'cpu')\n",
    "\n",
    "# Seed PyTorch's global RNG so repeated runs start from the same random state.\n",
    "torch.manual_seed(args.seed)\n",
    "# A background worker and pinned host memory are only requested for GPU runs.\n",
    "kwargs = {'num_workers': 1, 'pin_memory': True} if use_cuda else {}\n",
    "\n",
    "# Data pipeline: convert images to tensors, then standardise them with the\n",
    "# commonly used MNIST training-set mean (0.1307) and standard deviation (0.3081).\n",
    "transform = transforms.Compose([\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize((0.1307,), (0.3081,)),\n",
    "])\n",
    "trainset = datasets.MNIST('../data', train=True, download=True,\n",
    "                          transform=transform)\n",
    "testset = datasets.MNIST('../data', train=False,\n",
    "                         transform=transform)\n",
    "# Only the training split is shuffled; evaluation order stays fixed.\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    trainset, batch_size=args.batch_size, shuffle=True, **kwargs)\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "    testset, batch_size=args.test_batch_size, shuffle=False, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define CNN\n",
    "class Net(nn.Module):\n",
    "    \"\"\"Small convolutional classifier that returns log-probabilities over 10 classes.\n",
    "\n",
    "    Layout: two 3x3 stride-1 convolutions (1 -> 32 -> 64 channels), each followed\n",
    "    by ReLU, a 2x2 max-pool, then two fully connected layers (9216 -> 128 -> 10).\n",
    "    The first linear layer expects 9216 = 64 * 12 * 12 features, which is what a\n",
    "    1x28x28 input produces (28 -> 26 -> 24 after the unpadded convolutions, 12\n",
    "    after pooling).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        # Layers are created in the same order as before so that seeded\n",
    "        # initialisation and state_dict keys are unchanged.\n",
    "        self.conv1 = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=3, stride=1)\n",
    "        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=1)\n",
    "        self.fc1 = nn.Linear(in_features=9216, out_features=128)\n",
    "        self.fc2 = nn.Linear(in_features=128, out_features=10)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Map a batch of images ``x`` to per-class log-probabilities (softmax over dim 1).\"\"\"\n",
    "        # Convolutional feature extractor.\n",
    "        features = F.relu(self.conv1(x))\n",
    "        features = F.relu(self.conv2(features))\n",
    "        pooled = F.max_pool2d(features, 2)\n",
    "        # Keep the batch dimension, flatten everything else for the linear head.\n",
    "        hidden = F.relu(self.fc1(torch.flatten(pooled, 1)))\n",
    "        return F.log_softmax(self.fc2(hidden), dim=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train function\n",
    "def train(args, model, device, train_loader, optimizer, epoch):\n",
    "    \"\"\"Run one optimisation pass over ``train_loader``, updating ``model`` in place.\n",
    "\n",
    "    Each batch is moved to ``device``, the mean cross-entropy between the model\n",
    "    output and the targets is back-propagated, and ``optimizer`` takes one step.\n",
    "    ``args`` and ``epoch`` are accepted for the caller's convenience but are not\n",
    "    read here.  Returns ``None``.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    for inputs, labels in train_loader:\n",
    "        inputs = inputs.to(device)\n",
    "        labels = labels.to(device)\n",
    "\n",
    "        # Discard gradients accumulated by the previous step.\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # Forward pass and batch loss.\n",
    "        predictions = model(inputs)\n",
    "        batch_loss = F.cross_entropy(predictions, labels)\n",
    "\n",
    "        # Back-propagate, then apply the parameter update.\n",
    "        batch_loss.backward()\n",
    "        optimizer.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/LOCAL2/gjin/anaconda3/lib/python3.7/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='sum' instead.\n",
      "  warnings.warn(warning.format(ret))\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "trn_loss: 0.1022, trn_acc: 96.94%, test_loss: 0.1055, test_acc: 96.89%\n"
     ]
    }
   ],
   "source": [
    "# Predict function\n",
    "def eval_test(model, device, test_loader):\n",
    "    \"\"\"Compute the average loss and accuracy of ``model`` over ``test_loader``.\n",
    "\n",
    "    Args:\n",
    "        model: network whose output has one row per sample and one column per class.\n",
    "        device: torch device each batch is moved to before the forward pass.\n",
    "        test_loader: iterable of ``(data, target)`` batches exposing a ``dataset``\n",
    "            attribute that supports ``len()``.\n",
    "\n",
    "    Returns:\n",
    "        Tuple ``(loss, accuracy)``: the summed cross-entropy and the number of\n",
    "        correct predictions, each divided by ``len(test_loader.dataset)``.\n",
    "        Accuracy is a fraction, not a percentage.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    test_loss = 0\n",
    "    correct = 0\n",
    "    with torch.no_grad():\n",
    "        for data, target in test_loader:\n",
    "            data, target = data.to(device), target.to(device)\n",
    "            output = model(data)\n",
    "            # Sum per batch (instead of averaging) so that the final division by the\n",
    "            # dataset size is exact even when the last batch is smaller.\n",
    "            # reduction='sum' is the supported spelling of the deprecated\n",
    "            # size_average=False and does not emit a UserWarning.\n",
    "            test_loss += F.cross_entropy(output, target, reduction='sum').item()\n",
    "            # Index of the highest score per row is the predicted class.\n",
    "            pred = output.argmax(dim=1, keepdim=True)\n",
    "            correct += pred.eq(target.view_as(pred)).sum().item()\n",
    "    num_samples = len(test_loader.dataset)\n",
    "    test_loss /= num_samples\n",
    "    test_accuracy = correct / num_samples\n",
    "    return test_loss, test_accuracy\n",
    "\n",
    "# Helper shared by both branches of main()\n",
    "def _format_metrics(model):\n",
    "    \"\"\"Evaluate ``model`` on the module-level train and test loaders and return the summary line.\"\"\"\n",
    "    trnloss, trnacc = eval_test(model, device, train_loader)\n",
    "    tstloss, tstacc = eval_test(model, device, test_loader)\n",
    "    return ('trn_loss: {:.4f}, trn_acc: {:.2f}%, '\n",
    "            'test_loss: {:.4f}, test_acc: {:.2f}%').format(\n",
    "                trnloss, 100. * trnacc, tstloss, 100. * tstacc)\n",
    "\n",
    "\n",
    "# Main function, train the initial model or load the model\n",
    "def main():\n",
    "    \"\"\"Train a fresh ``Net`` and save it, or load a saved one and report its metrics.\n",
    "\n",
    "    Reads the module-level ``args``, ``device``, ``train_loader`` and ``test_loader``.\n",
    "    With ``args.load_model`` set, the weights in ``<model_dir>/final_model.pt`` are\n",
    "    loaded and train/test loss and accuracy are printed once.  Otherwise the model\n",
    "    is trained with SGD for ``args.epochs`` epochs, metrics are printed after each\n",
    "    epoch, and the final weights are written to the same checkpoint path.\n",
    "    \"\"\"\n",
    "    model = Net().to(device)\n",
    "    checkpoint_path = os.path.join(args.model_dir, 'final_model.pt')\n",
    "\n",
    "    if args.load_model:\n",
    "        # map_location remaps stored tensors onto the active device, so a\n",
    "        # checkpoint written on a GPU still loads on a CPU-only machine.\n",
    "        model.load_state_dict(torch.load(checkpoint_path, map_location=device))\n",
    "        print(_format_metrics(model))\n",
    "        return\n",
    "\n",
    "    # Train initial model\n",
    "    optimizer = optim.SGD(model.parameters(), lr=args.lr)\n",
    "    for epoch in range(1, args.epochs + 1):\n",
    "        start_time = time.time()\n",
    "\n",
    "        train(args, model, device, train_loader, optimizer, epoch)\n",
    "\n",
    "        # The reported seconds cover training plus both evaluation passes,\n",
    "        # because the clock is read only after the metrics are computed.\n",
    "        summary = _format_metrics(model)\n",
    "        print('Epoch {}: {}s, {}'.format(\n",
    "            epoch, int(time.time() - start_time), summary))\n",
    "\n",
    "    # Save model\n",
    "    torch.save(model.state_dict(), checkpoint_path)\n",
    "\n",
    "# Entry point: calls main() when this code runs as the top-level module\n",
    "# (__name__ == '__main__'), so importing it does not start a run.\n",
    "if __name__ == '__main__':\n",
    "    main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1.8.1+cu111\n"
     ]
    }
   ],
   "source": [
    "# Show the installed PyTorch version so the environment behind the results is recorded.\n",
    "print(torch.__version__)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
